import wandb
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from tenacity import retry, wait_exponential, stop_after_attempt

@retry(wait=wait_exponential(multiplier=1, min=2, max=20), stop=stop_after_attempt(5))
def get_wandb_runs(api, entity, project_name):
    '''
    Return every run of ``entity/project_name`` as a list, retrying on failure.

    ``api.runs(...)`` returns a lazily-paginated collection: the network requests
    happen while it is iterated, not when it is created. Materializing it here with
    ``list(...)`` makes those requests happen inside this function, so the
    exponential-backoff retry (up to 5 attempts, waits of 2-20 s) actually covers them.
    '''
    return list(api.runs(f"{entity}/{project_name}"))

def _rank_runs(runs, ranking_metric):
    '''
    Keep the runs whose summary reports a numeric ``ranking_metric`` and return them
    sorted by that value in descending order (rank 0 = largest value).

    Runs with a non-numeric value are skipped with a message, instead of letting the
    sort raise TypeError and abort the whole plot.
    '''
    scored = []
    for run in runs:
        if ranking_metric not in run.summary:
            continue
        value = run.summary[ranking_metric]
        if not isinstance(value, (int, float, np.integer, np.floating)):
            print(f"Skipping run {getattr(run, 'name', '?')}: "
                  f"{ranking_metric!r} is not numeric ({value!r})")
            continue
        scored.append((value, run))
    # list.sort is stable (also with reverse=True), so tied runs keep the order in
    # which they were returned, exactly as when sorting the runs directly.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [run for _, run in scored]


def _download_histories(ranked_runs, metric):
    '''
    Download the ``metric`` history of each ranked run.

    Returns a list of ``(rank, history)`` pairs. Runs whose download fails are
    reported and left out, but the remaining runs keep their original rank so
    their colors still reflect the ranking.
    '''
    histories = []
    for rank, run in enumerate(ranked_runs):
        print(f"Downloading history for run {rank}")
        try:
            histories.append((rank, run.history(keys=[metric])))
        except Exception as e:  # best effort: one bad run should not sink the plot
            print(f"Failed to retrieve history for run {rank}: {e}")
    return histories


def _upload_plot(pdf_file, wandb_project, wandb_entity, run_name, wandb_dir):
    '''
    Upload ``pdf_file`` to a new wandb run as a ``report`` artifact named ``plot_pdf``.

    The run is always finished (even if the upload raises) so it is not left open,
    e.g. in a notebook session.
    '''
    os.environ['WANDB_DIR'] = wandb_dir
    wandb.init(project=wandb_project, entity=wandb_entity, name=run_name)
    try:
        artifact = wandb.Artifact('plot_pdf', type='report')
        artifact.add_file(pdf_file)
        wandb.log_artifact(artifact)
    finally:
        wandb.finish()


def plot_metric_from_wandb(project_name, metric, ranking_metric, colors, labels, title, xlabel, ylabel, log_every=5,
                           start_idx=0, end_idx=None,
                           figsize=(10, 6), title_fontsize=20, label_fontsize=16, legend_fontsize=14, tick_fontsize=12,
                           save_directory='/directory/for/saving_plots',
                           wandb_project='Extract_Results', wandb_entity='team-name', run_name='custom_run',
                           y_min=None, y_max=None, line_thickness=2, markers=None, max_markers=20,
                           x_logarithm=False, y_logarithm=False, manual_ticks=False,
                           title_color='cyan', alpha=0.2, show_legend=False,
                           wandb_dir='/directory/for/wandb_log/'):
    '''
    Plot one metric for every run of a wandb sweep, color-coded by rank.

    Intended for visualizing hyperparameter sensitivity over a sweep that has already
    been logged to wandb. Runs are ranked by ``ranking_metric`` (read from each run's
    summary, sorted in descending order). Each run's ``metric`` history is then drawn
    in a color determined by its rank. The figure is saved as a PDF in
    ``save_directory``, shown, and uploaded to wandb as a ``report`` artifact. Any
    error is printed (with traceback) rather than raised.

    Args:
        project_name: wandb project holding the sweep runs.
        metric: history key plotted on the y-axis.
        ranking_metric: summary key used to rank (and color) the runs.
        colors: optional sequence of matplotlib colors (or a single color string),
            cycled by rank. When None/empty, colors are sampled from the viridis
            colormap at ``rank / num_runs``.
        labels: optional sequence of line labels, cycled by rank.
        title, xlabel, ylabel: title and axis label texts.
        log_every: multiplier turning a history row position into an x value.
        start_idx, end_idx: slice of each run's history to plot. ``end_idx=None``
            means "to the end of that run". Runs shorter than ``end_idx`` are
            plotted as far as they go.
        figsize, title_fontsize, label_fontsize, legend_fontsize, tick_fontsize:
            figure styling.
        save_directory: directory the PDF is written to (created if missing).
        wandb_project, wandb_entity, run_name: destination of the uploaded plot.
            ``wandb_entity`` is also the entity the sweep runs are read from.
        y_min, y_max: optional y-axis bounds; either may be given on its own.
        line_thickness, alpha: line width and opacity.
        markers: optional sequence of marker styles, cycled by rank.
        max_markers: maximum number of markers drawn per line.
        x_logarithm, y_logarithm: use log-scaled axes.
        manual_ticks: use the fixed tick positions below with plain number formatting.
        title_color: color of the title text.
        show_legend: draw a legend using ``legend_fontsize``.
        wandb_dir: value assigned to the ``WANDB_DIR`` environment variable before
            ``wandb.init``.

    NOTE(review): x values are ``row position * log_every``. If wandb's
    ``run.history()`` down-samples long histories, these positions will not match
    the real logging steps. Confirm this for long runs.
    '''
    import traceback

    if isinstance(colors, str):
        colors = [colors]

    fig = None
    try:
        os.makedirs(save_directory, exist_ok=True)

        api = wandb.Api(timeout=5000)
        runs = get_wandb_runs(api, wandb_entity, project_name)

        ranked_runs = _rank_runs(runs, ranking_metric)
        if not ranked_runs:
            print(f"No runs in {wandb_entity}/{project_name} report {ranking_metric!r}; nothing to plot.")
            return

        histories = _download_histories(ranked_runs, metric)
        if not histories:
            print("No run histories could be downloaded; nothing to plot.")
            return

        num_runs = len(ranked_runs)
        color_map = plt.colormaps.get_cmap('viridis')

        fig, ax = plt.subplots(figsize=figsize)

        for rank, history in histories:
            if metric not in history:
                continue

            # Slice per run: end_idx is never overwritten, and x is built from the
            # actual slice length, so runs of different lengths can share one plot.
            values = np.asarray(history[metric])[start_idx:end_idx]
            if values.size == 0:
                print(f"Run {rank} has no '{metric}' data in the requested range; skipping.")
                continue
            x = np.arange(start_idx, start_idx + values.size)

            color = colors[rank % len(colors)] if colors else color_map(rank / num_runs)
            label = labels[rank % len(labels)] if labels else None
            marker = markers[rank % len(markers)] if markers else None
            # Evenly spread marker positions; unique() drops duplicates on short lines.
            markevery = np.unique(np.linspace(0, values.size - 1, max_markers, dtype=int))

            ax.plot(x * log_every, values, color=color, label=label,
                    linewidth=line_thickness, marker=marker, markevery=markevery, alpha=alpha)

        if x_logarithm:
            ax.set_xscale('log')
        if y_logarithm:
            ax.set_yscale('log')

        ax.set_title(title, fontsize=title_fontsize, color=title_color)
        ax.set_xlabel(xlabel, fontsize=label_fontsize)
        ax.set_ylabel(ylabel, fontsize=label_fontsize)
        ax.grid(True)
        ax.tick_params(axis='both', which='both', labelsize=tick_fontsize)
        if show_legend:
            ax.legend(fontsize=legend_fontsize)

        # A bound left as None keeps its current (autoscaled) value.
        if y_min is not None or y_max is not None:
            ax.set_ylim(bottom=y_min, top=y_max)
        ax.autoscale(enable=True, axis='x', tight=True)

        if manual_ticks:
            ax.xaxis.set_major_formatter(plt.ScalarFormatter())
            ax.yaxis.set_major_formatter(plt.ScalarFormatter())
            ax.xaxis.set_ticks([30, 40, 60, 100, 200])
            ax.yaxis.set_ticks([2, 3, 4])
        fig.tight_layout()

        pdf_file = os.path.join(save_directory, 'wandb_plot.pdf')
        fig.savefig(pdf_file, format='pdf')
        plt.show()

        _upload_plot(pdf_file, wandb_project, wandb_entity, run_name, wandb_dir)
        print("Plot successfully saved and uploaded to Weights and Biases.")

    except Exception as e:
        # Top-level boundary: report the failure (with traceback) instead of raising.
        print(f"An error occurred: {e}")
        traceback.print_exc()
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        if fig is not None:
            plt.close(fig)

# Project name (on wandb)
project_name = "project name here"

# Plotted metric
metric = "test accuracy"

# Ranking metric (run colors are assigned by rank on ranking_metric)
ranking_metric = "test loss"

# Run the network-bound plotting job only when executed as a script, so importing
# this module (e.g. to reuse plot_metric_from_wandb) has no side effects.
if __name__ == "__main__":
    plot_metric_from_wandb(
        project_name=project_name,
        metric=metric,
        ranking_metric=ranking_metric,
        colors=None,  # Colors will be generated dynamically
        labels=["Run 1", "Run 2"],
        title="Direct Joint Adap.",
        xlabel="Communication Rounds",
        ylabel="Test Accuracy",
        title_color='green',
        log_every=5,
        start_idx=0,
        end_idx=100,
        figsize=(7, 6),
        title_fontsize=43,
        label_fontsize=33,
        legend_fontsize=18,
        tick_fontsize=18,
        wandb_entity='some-agents',
        run_name="custom_run",
        line_thickness=2,
        markers=['d'],
        max_markers=12,
        x_logarithm=False,
        y_logarithm=False,
        manual_ticks=False,
        y_min=0.0,
        y_max=0.225
    )